{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tile Coding\n",
    "---\n",
    "\n",
    "Tile coding is an innovative way of discretizing a continuous space that enables better generalization compared to a single grid-based approach. The fundamental idea is to create several overlapping grids or _tilings_; then for any given sample value, you need only check which tiles it lies in. You can then encode the original continuous value by a vector of integer indices or bits that identifies each activated tile.\n",
    "\n",
    "### 1. Import the Necessary Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import common libraries\n",
    "import sys\n",
    "import gym\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Set plotting options\n",
    "%matplotlib inline\n",
    "plt.style.use('ggplot')\n",
    "np.set_printoptions(precision=3, linewidth=120)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Specify the Environment, and Explore the State and Action Spaces\n",
    "\n",
    "We'll use [OpenAI Gym](https://gym.openai.com/) environments to test and develop our algorithms. These simulate a variety of classic as well as contemporary reinforcement learning tasks.  Let's begin with an environment that has a continuous state space, but a discrete action space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an environment\n",
    "env = gym.make('Acrobot-v1')\n",
    "env.seed(505);\n",
    "\n",
    "# Explore state (observation) space\n",
    "print(\"State space:\", env.observation_space)\n",
    "print(\"- low:\", env.observation_space.low)\n",
    "print(\"- high:\", env.observation_space.high)\n",
    "\n",
    "# Explore action space\n",
    "print(\"Action space:\", env.action_space)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the state space is multi-dimensional, with most dimensions ranging from -1 to 1 (positions of the two joints), while the final two dimensions have a larger range. How do we discretize such a space using tiles?\n",
    "\n",
    "### 3. Tiling\n",
    "\n",
    "Let's first design a way to create a single tiling for a given state space. This is very similar to a uniform grid! The only difference is that you should include an offset for each dimension that shifts the split points.\n",
    "\n",
    "For instance, if `low = [-1.0, -5.0]`, `high = [1.0, 5.0]`, `bins = (10, 10)`, and `offsets = (-0.1, 0.5)`, then return a list of 2 NumPy arrays (2 dimensions) each containing the following split points (9 split points per dimension):\n",
    "\n",
    "```\n",
    "[array([-0.9, -0.7, -0.5, -0.3, -0.1,  0.1,  0.3,  0.5,  0.7]),\n",
    " array([-3.5, -2.5, -1.5, -0.5,  0.5,  1.5,  2.5,  3.5,  4.5])]\n",
    "```\n",
    "\n",
    "Notice how the split points for the first dimension are offset by `-0.1`, and for the second dimension are offset by `+0.5`. This might mean that some of our tiles, especially along the perimeter, are partially outside the valid state space, but that is unavoidable and harmless."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_tiling_grid(low, high, bins=(10, 10), offsets=(0.0, 0.0)):\n",
    "    \"\"\"Define a uniformly-spaced grid that can be used for tile-coding a space.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    low : array_like\n",
    "        Lower bounds for each dimension of the continuous space.\n",
    "    high : array_like\n",
    "        Upper bounds for each dimension of the continuous space.\n",
    "    bins : tuple\n",
    "        Number of bins or tiles along each corresponding dimension.\n",
    "    offsets : tuple\n",
    "        Split points for each dimension should be offset by these values.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    grid : list of array_like\n",
    "        A list of arrays containing split points for each dimension.\n",
    "    \"\"\"\n",
    "    # TODO: Implement this\n",
    "    pass\n",
    "\n",
    "\n",
    "low = [-1.0, -5.0]\n",
    "high = [1.0, 5.0]\n",
    "create_tiling_grid(low, high, bins=(10, 10), offsets=(-0.1, 0.5))  # [test]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can now use this function to define a set of tilings that are a little offset from each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_tilings(low, high, tiling_specs):\n",
    "    \"\"\"Define multiple tilings using the provided specifications.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    low : array_like\n",
    "        Lower bounds for each dimension of the continuous space.\n",
    "    high : array_like\n",
    "        Upper bounds for each dimension of the continuous space.\n",
    "    tiling_specs : list of tuples\n",
    "        A sequence of (bins, offsets) to be passed to create_tiling_grid().\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tilings : list\n",
    "        A list of tilings (grids), each produced by create_tiling_grid().\n",
    "    \"\"\"\n",
    "    # TODO: Implement this\n",
    "    pass\n",
    "\n",
    "\n",
    "# Tiling specs: [(<bins>, <offsets>), ...]\n",
    "tiling_specs = [((10, 10), (-0.066, -0.33)),\n",
    "                ((10, 10), (0.0, 0.0)),\n",
    "                ((10, 10), (0.066, 0.33))]\n",
    "tilings = create_tilings(low, high, tiling_specs)"
   ]
  },
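  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The offsets in `tiling_specs` above are roughly one third of a tile width in each dimension (the tiles are 0.2 and 1.0 wide). As a hedged sketch, specs like these could also be generated programmatically. The helper name `make_tiling_specs` and its evenly-spread offset rule are assumptions made for illustration, not part of this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def make_tiling_specs(low, high, bins, n_tilings):\n",
    "    # Hypothetical helper: spread n_tilings offsets evenly across one tile width, centered on zero\n",
    "    tile_width = (np.asarray(high, dtype=float) - np.asarray(low, dtype=float)) / np.asarray(bins)\n",
    "    fractions = (np.arange(n_tilings) - (n_tilings - 1) / 2) / n_tilings\n",
    "    return [(tuple(bins), tuple(f * tile_width)) for f in fractions]\n",
    "\n",
    "specs = make_tiling_specs([-1.0, -5.0], [1.0, 5.0], bins=(10, 10), n_tilings=3)\n",
    "for bins, offsets in specs:\n",
    "    print(bins, np.round(offsets, 3))\n",
    "```\n",
    "\n",
    "For three tilings this gives offsets of about -1/3, 0, and +1/3 of a tile width, close to the hand-written values above. The result has the same `[(<bins>, <offsets>), ...]` layout, so it could be passed to `create_tilings()` in place of `tiling_specs`."
   ]
  },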
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It may be hard to gauge whether you are getting desired results or not. So let's try to visualize these tilings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.lines import Line2D\n",
    "\n",
    "def visualize_tilings(tilings):\n",
    "    \"\"\"Plot each tiling as a grid.\"\"\"\n",
    "    prop_cycle = plt.rcParams['axes.prop_cycle']\n",
    "    colors = prop_cycle.by_key()['color']\n",
    "    linestyles = ['-', '--', ':']\n",
    "    legend_lines = []\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "    for i, grid in enumerate(tilings):\n",
    "        for x in grid[0]:\n",
    "            l = ax.axvline(x=x, color=colors[i % len(colors)], linestyle=linestyles[i % len(linestyles)], label=i)\n",
    "        for y in grid[1]:\n",
    "            l = ax.axhline(y=y, color=colors[i % len(colors)], linestyle=linestyles[i % len(linestyles)])\n",
    "        legend_lines.append(l)\n",
    "    ax.grid('off')\n",
    "    ax.legend(legend_lines, [\"Tiling #{}\".format(t) for t in range(len(legend_lines))], facecolor='white', framealpha=0.9)\n",
    "    ax.set_title(\"Tilings\")\n",
    "    return ax  # return Axis object to draw on later, if needed\n",
    "\n",
    "\n",
    "visualize_tilings(tilings);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Now that we have a way to generate these tilings, we can next write our encoding function that will convert any given continuous state value to a discrete vector.\n",
    "\n",
    "### 4. Tile Encoding\n",
    "\n",
    "Implement the following to produce a vector that contains the indices for each tile that the input state value belongs to. The shape of the vector can be the same as the arrangment of tiles you have, or it can be ultimately flattened for convenience.\n",
    "\n",
    "You can use the same `discretize()` function here from grid-based discretization, and simply call it for each tiling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discretize(sample, grid):\n",
    "    \"\"\"Discretize a sample as per given grid.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    sample : array_like\n",
    "        A single sample from the (original) continuous space.\n",
    "    grid : list of array_like\n",
    "        A list of arrays containing split points for each dimension.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    discretized_sample : array_like\n",
    "        A sequence of integers with the same number of dimensions as sample.\n",
    "    \"\"\"\n",
    "    # TODO: Implement this\n",
    "    pass\n",
    "\n",
    "\n",
    "def tile_encode(sample, tilings, flatten=False):\n",
    "    \"\"\"Encode given sample using tile-coding.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    sample : array_like\n",
    "        A single sample from the (original) continuous space.\n",
    "    tilings : list\n",
    "        A list of tilings (grids), each produced by create_tiling_grid().\n",
    "    flatten : bool\n",
    "        If true, flatten the resulting binary arrays into a single long vector.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    encoded_sample : list or array_like\n",
    "        A list of binary vectors, one for each tiling, or flattened into one.\n",
    "    \"\"\"\n",
    "    # TODO: Implement this\n",
    "    pass\n",
    "\n",
    "\n",
    "# Test with some sample values\n",
    "samples = [(-1.2 , -5.1 ),\n",
    "           (-0.75,  3.25),\n",
    "           (-0.5 ,  0.0 ),\n",
    "           ( 0.25, -1.9 ),\n",
    "           ( 0.15, -1.75),\n",
    "           ( 0.75,  2.5 ),\n",
    "           ( 0.7 , -3.7 ),\n",
    "           ( 1.0 ,  5.0 )]\n",
    "encoded_samples = [tile_encode(sample, tilings) for sample in samples]\n",
    "print(\"\\nSamples:\", repr(samples), sep=\"\\n\")\n",
    "print(\"\\nEncoded samples:\", repr(encoded_samples), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we did not flatten the encoding above, which is why each sample's representation is a pair of indices for each tiling. This makes it easy to visualize it using the tilings."
   ]
  },
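  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The introduction mentions that a tile-coded value can also be represented as bits rather than indices. A minimal sketch of that conversion follows, assuming every tiling has the same number of tiles per dimension. The function name `to_binary_features` and its `tiles_per_dim` parameter are hypothetical and not part of this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def to_binary_features(encoded_sample, tiles_per_dim):\n",
    "    # encoded_sample: list of index tuples, one per tiling (the unflattened encoding)\n",
    "    # tiles_per_dim: assumed number of tiles along each dimension, shared by all tilings\n",
    "    tiles_per_tiling = int(np.prod(tiles_per_dim))\n",
    "    features = np.zeros(len(encoded_sample) * tiles_per_tiling, dtype=int)\n",
    "    for t, idx in enumerate(encoded_sample):\n",
    "        flat = np.ravel_multi_index(idx, tiles_per_dim)  # position of the tile within its tiling\n",
    "        features[t * tiles_per_tiling + flat] = 1  # exactly one active tile per tiling\n",
    "    return features\n",
    "\n",
    "features = to_binary_features([(6, 3), (6, 3), (5, 2)], tiles_per_dim=(10, 10))\n",
    "print(features.sum(), np.flatnonzero(features))\n",
    "```\n",
    "\n",
    "The resulting vector has one active bit per tiling, which is the form typically used with linear function approximation; the index form used in the rest of this notebook is more convenient for table lookups."
   ]
  },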
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "def visualize_encoded_samples(samples, encoded_samples, tilings, low=None, high=None):\n",
    "    \"\"\"Visualize samples by activating the respective tiles.\"\"\"\n",
    "    samples = np.array(samples)  # for ease of indexing\n",
    "\n",
    "    # Show tiling grids\n",
    "    ax = visualize_tilings(tilings)\n",
    "    \n",
    "    # If bounds (low, high) are specified, use them to set axis limits\n",
    "    if low is not None and high is not None:\n",
    "        ax.set_xlim(low[0], high[0])\n",
    "        ax.set_ylim(low[1], high[1])\n",
    "    else:\n",
    "        # Pre-render (invisible) samples to automatically set reasonable axis limits, and use them as (low, high)\n",
    "        ax.plot(samples[:, 0], samples[:, 1], 'o', alpha=0.0)\n",
    "        low = [ax.get_xlim()[0], ax.get_ylim()[0]]\n",
    "        high = [ax.get_xlim()[1], ax.get_ylim()[1]]\n",
    "\n",
    "    # Map each encoded sample (which is really a list of indices) to the corresponding tiles it belongs to\n",
    "    tilings_extended = [np.hstack((np.array([low]).T, grid, np.array([high]).T)) for grid in tilings]  # add low and high ends\n",
    "    tile_centers = [(grid_extended[:, 1:] + grid_extended[:, :-1]) / 2 for grid_extended in tilings_extended]  # compute center of each tile\n",
    "    tile_toplefts = [grid_extended[:, :-1] for grid_extended in tilings_extended]  # compute topleft of each tile\n",
    "    tile_bottomrights = [grid_extended[:, 1:] for grid_extended in tilings_extended]  # compute bottomright of each tile\n",
    "\n",
    "    prop_cycle = plt.rcParams['axes.prop_cycle']\n",
    "    colors = prop_cycle.by_key()['color']\n",
    "    for sample, encoded_sample in zip(samples, encoded_samples):\n",
    "        for i, tile in enumerate(encoded_sample):\n",
    "            # Shade the entire tile with a rectangle\n",
    "            topleft = tile_toplefts[i][0][tile[0]], tile_toplefts[i][1][tile[1]]\n",
    "            bottomright = tile_bottomrights[i][0][tile[0]], tile_bottomrights[i][1][tile[1]]\n",
    "            ax.add_patch(Rectangle(topleft, bottomright[0] - topleft[0], bottomright[1] - topleft[1],\n",
    "                                   color=colors[i], alpha=0.33))\n",
    "\n",
    "            # In case sample is outside tile bounds, it may not have been highlighted properly\n",
    "            if any(sample < topleft) or any(sample > bottomright):\n",
    "                # So plot a point in the center of the tile and draw a connecting line\n",
    "                cx, cy = tile_centers[i][0][tile[0]], tile_centers[i][1][tile[1]]\n",
    "                ax.add_line(Line2D([sample[0], cx], [sample[1], cy], color=colors[i]))\n",
    "                ax.plot(cx, cy, 's', color=colors[i])\n",
    "    \n",
    "    # Finally, plot original samples\n",
    "    ax.plot(samples[:, 0], samples[:, 1], 'o', color='r')\n",
    "\n",
    "    ax.margins(x=0, y=0)  # remove unnecessary margins\n",
    "    ax.set_title(\"Tile-encoded samples\")\n",
    "    return ax\n",
    "\n",
    "visualize_encoded_samples(samples, encoded_samples, tilings);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the results and make sure you understand how the corresponding tiles are being chosen. Note that some samples may have one or more tiles in common.\n",
    "\n",
    "### 5. Q-Table with Tile Coding\n",
    "\n",
    "The next step is to design a special Q-table that is able to utilize this tile coding scheme. It should have the same kind of interface as a regular table, i.e. given a `<state, action>` pair, it should return a `<value>`. Similarly, it should also allow you to update the `<value>` for a given `<state, action>` pair (note that this should update all the tiles that `<state>` belongs to).\n",
    "\n",
    "The `<state>` supplied here is assumed to be from the original continuous state space, and `<action>` is discrete (and integer index). The Q-table should internally convert the `<state>` to its tile-coded representation when required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QTable:\n",
    "    \"\"\"Simple Q-table.\"\"\"\n",
    "\n",
    "    def __init__(self, state_size, action_size):\n",
    "        \"\"\"Initialize Q-table.\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        state_size : tuple\n",
    "            Number of discrete values along each dimension of state space.\n",
    "        action_size : int\n",
    "            Number of discrete actions in action space.\n",
    "        \"\"\"\n",
    "        self.state_size = state_size\n",
    "        self.action_size = action_size\n",
    "\n",
    "        # TODO: Create Q-table, initialize all Q-values to zero\n",
    "        # Note: If state_size = (9, 9), action_size = 2, q_table.shape should be (9, 9, 2)\n",
    "        \n",
    "        print(\"QTable(): size =\", self.q_table.shape)\n",
    "\n",
    "\n",
    "class TiledQTable:\n",
    "    \"\"\"Composite Q-table with an internal tile coding scheme.\"\"\"\n",
    "    \n",
    "    def __init__(self, low, high, tiling_specs, action_size):\n",
    "        \"\"\"Create tilings and initialize internal Q-table(s).\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        low : array_like\n",
    "            Lower bounds for each dimension of state space.\n",
    "        high : array_like\n",
    "            Upper bounds for each dimension of state space.\n",
    "        tiling_specs : list of tuples\n",
    "            A sequence of (bins, offsets) to be passed to create_tilings() along with low, high.\n",
    "        action_size : int\n",
    "            Number of discrete actions in action space.\n",
    "        \"\"\"\n",
    "        self.tilings = create_tilings(low, high, tiling_specs)\n",
    "        self.state_sizes = [tuple(len(splits)+1 for splits in tiling_grid) for tiling_grid in self.tilings]\n",
    "        self.action_size = action_size\n",
    "        self.q_tables = [QTable(state_size, self.action_size) for state_size in self.state_sizes]\n",
    "        print(\"TiledQTable(): no. of internal tables = \", len(self.q_tables))\n",
    "    \n",
    "    def get(self, state, action):\n",
    "        \"\"\"Get Q-value for given <state, action> pair.\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        state : array_like\n",
    "            Vector representing the state in the original continuous space.\n",
    "        action : int\n",
    "            Index of desired action.\n",
    "        \n",
    "        Returns\n",
    "        -------\n",
    "        value : float\n",
    "            Q-value of given <state, action> pair, averaged from all internal Q-tables.\n",
    "        \"\"\"\n",
    "        # TODO: Encode state to get tile indices\n",
    "        \n",
    "        # TODO: Retrieve q-value for each tiling, and return their average\n",
    "        pass\n",
    "\n",
    "    def update(self, state, action, value, alpha=0.1):\n",
    "        \"\"\"Soft-update Q-value for given <state, action> pair to value.\n",
    "        \n",
    "        Instead of overwriting Q(state, action) with value, perform soft-update:\n",
    "            Q(state, action) = alpha * value + (1.0 - alpha) * Q(state, action)\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        state : array_like\n",
    "            Vector representing the state in the original continuous space.\n",
    "        action : int\n",
    "            Index of desired action.\n",
    "        value : float\n",
    "            Desired Q-value for <state, action> pair.\n",
    "        alpha : float\n",
    "            Update factor to perform soft-update, in [0.0, 1.0] range.\n",
    "        \"\"\"\n",
    "        # TODO: Encode state to get tile indices\n",
    "        \n",
    "        # TODO: Update q-value for each tiling by update factor alpha\n",
    "        pass\n",
    "\n",
    "\n",
    "# Test with a sample Q-table\n",
    "tq = TiledQTable(low, high, tiling_specs, 2)\n",
    "s1 = 3; s2 = 4; a = 0; q = 1.0\n",
    "print(\"[GET]    Q({}, {}) = {}\".format(samples[s1], a, tq.get(samples[s1], a)))  # check value at sample = s1, action = a\n",
    "print(\"[UPDATE] Q({}, {}) = {}\".format(samples[s2], a, q)); tq.update(samples[s2], a, q)  # update value for sample with some common tile(s)\n",
    "print(\"[GET]    Q({}, {}) = {}\".format(samples[s1], a, tq.get(samples[s1], a)))  # check value again, should be slightly updated"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you update the q-value for a particular state (say, `(0.25, -1.91)`) and action (say, `0`), then you should notice the q-value of a nearby state (e.g. `(0.15, -1.75)` and same action) has changed as well! This is how tile-coding is able to generalize values across the state space better than a single uniform grid."
   ]
  },
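  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough worked example of this effect, suppose (hypothetically) that all Q-values start at zero, there are $n$ tilings, and the updated state and the nearby state share $k$ of their $n$ active tiles. The symbols $n$, $k$, and $v$ are introduced here for illustration. A single soft update toward a target value $v$ with update factor $\\alpha$ sets each of the updated state's tiles to $\\alpha v$, so averaging over the nearby state's tiles gives:\n",
    "\n",
    "$$Q(s_{\\text{nearby}}, a) = \\frac{1}{n} \\sum_{i=1}^{n} Q_i(s_{\\text{nearby}}, a) = \\frac{k \\alpha v}{n}$$\n",
    "\n",
    "For instance, with $n = 3$, $\\alpha = 0.1$, $v = 1.0$, and an assumed overlap of $k = 2$, the nearby state's value would move from $0$ to about $0.067$, while a distant state with $k = 0$ would stay at $0$."
   ]
  },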
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Implement a Q-Learning Agent using Tile-Coding\n",
    "\n",
    "Now it's your turn to apply this discretization technique to design and test a complete learning agent! "
   ]
  },
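  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one possible starting point, here is a hedged sketch of the core of a Q-learning agent. It assumes only a table object exposing `get(state, action)` and `update(state, action, value, alpha)`, like the `TiledQTable` interface above. The class names `QLearningAgentSketch` and `DictTable`, and the hyperparameter defaults, are hypothetical; `DictTable` is a crude stand-in that exists only so the sketch runs on its own.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "class QLearningAgentSketch:\n",
    "    # Hypothetical agent built around a table with get() and update() methods\n",
    "    def __init__(self, table, action_size, alpha=0.1, gamma=0.99, epsilon=0.1, seed=0):\n",
    "        self.table, self.action_size = table, action_size\n",
    "        self.alpha, self.gamma, self.epsilon = alpha, gamma, epsilon\n",
    "        self.rng = np.random.RandomState(seed)\n",
    "\n",
    "    def greedy_action(self, state):\n",
    "        q_values = [self.table.get(state, a) for a in range(self.action_size)]\n",
    "        return int(np.argmax(q_values))\n",
    "\n",
    "    def act(self, state):\n",
    "        # Epsilon-greedy: explore with probability epsilon, otherwise exploit\n",
    "        if self.rng.uniform() < self.epsilon:\n",
    "            return int(self.rng.randint(self.action_size))\n",
    "        return self.greedy_action(state)\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        # Q-learning target: reward + gamma * max over next actions, with no bootstrap at episode end\n",
    "        best_next = 0.0 if done else max(self.table.get(next_state, a) for a in range(self.action_size))\n",
    "        target = reward + self.gamma * best_next\n",
    "        self.table.update(state, action, target, self.alpha)\n",
    "\n",
    "\n",
    "class DictTable:\n",
    "    # Stand-in for a tiled Q-table: rounds states to one decimal and stores values in a dict\n",
    "    def __init__(self):\n",
    "        self.q = {}\n",
    "\n",
    "    def _key(self, state, action):\n",
    "        return (tuple(np.round(state, 1)), action)\n",
    "\n",
    "    def get(self, state, action):\n",
    "        return self.q.get(self._key(state, action), 0.0)\n",
    "\n",
    "    def update(self, state, action, value, alpha=0.1):\n",
    "        key = self._key(state, action)\n",
    "        self.q[key] = alpha * value + (1.0 - alpha) * self.q.get(key, 0.0)\n",
    "\n",
    "\n",
    "agent = QLearningAgentSketch(DictTable(), action_size=3)\n",
    "agent.learn(state=(0.2, -1.0), action=1, reward=-1.0, next_state=(0.3, -0.9), done=False)\n",
    "print(agent.table.get((0.2, -1.0), 1))\n",
    "```\n",
    "\n",
    "In a full agent, `act()` and `learn()` would be called once per environment step inside an episode loop, and exploration would usually be decayed over time; those parts are left out of this sketch."
   ]
  },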
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
